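'''
Helpers for proxying Analytics requests: obtain an OIDC access token with the
client-credentials grant and attach it to outgoing HTTP requests.
'''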

import time

from enum import Enum

from typing import Optional, Any

import requests

# Token endpoint of the 'redhat-external' realm on sso.redhat.com; used as
# OIDCClient's default ``token_url``.
DEFAULT_OIDC_TOKEN_ENDPOINT = 'https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token'


class TokenError(requests.RequestException):
    '''
    Raised when token generation request fails.

    Useful for differentiating request failure for make_request() vs.
    other requests issued to get a token i.e.:

    try:
        client = OIDCClient(...)
        client.make_request(...)
    except TokenError as e:
        print(f"Token generation failed due to {e.__cause__}")
    except requests.RequestException:
        print("API request failed")

    TokenError subclasses requests.RequestException, so it must be caught
    before RequestException (as above) to be told apart.
    '''

    def __init__(self, message="Token generation request failed", response=None):
        # Delegate to RequestException so ``response`` is stored the same way
        # as for other requests exceptions, and ``request`` is filled in from
        # ``response.request`` when available (kept for debugging).
        super().__init__(message, response=response)


def _now(reason: str):
    '''
    Wrapper for time. Helps with testing.
    '''
    return int(time.time())


class TokenType(Enum):
    '''
    Access token type as returned by the remote API.

    Lookup by value is case-insensitive: OAuth 2.0 (RFC 6749, section 5.1)
    defines ``token_type`` as case insensitive, so a server may legitimately
    return ``"bearer"`` instead of ``"Bearer"``.
    '''

    BEARER = 'Bearer'

    @classmethod
    def _missing_(cls, value: object) -> Optional['TokenType']:
        # Enum calls this only after an exact value match has failed.
        if isinstance(value, str):
            wanted = value.casefold()
            for member in cls:
                if member.value.casefold() == wanted:
                    return member
        # Returning None makes Enum raise its usual ValueError.
        return None


class Token:
    '''
    Token data generated by OIDC response.

    The creation time is recorded when this object is constructed, and expiry
    is measured from that moment (not from when the server issued the token).
    '''

    access_token: str
    expires_in: int  # lifetime in seconds, counted from object creation
    refresh_expires_in: int
    token_type: TokenType
    not_before_policy: int  # from the JSON key 'not-before-policy'
    scope: str

    def __init__(
        self,
        access_token: str,
        expires_in: int,
        refresh_expires_in: int,
        token_type: TokenType,
        not_before_policy: int,
        scope: str,
    ):
        self.access_token = access_token
        self.expires_in = expires_in
        self.refresh_expires_in = refresh_expires_in
        self.token_type = token_type
        self.not_before_policy = not_before_policy
        self.scope = scope

        # Creation timestamp (whole seconds); the base for expires_at.
        self._now = _now(reason='token-creation')

    @property
    def expires_at(self) -> int:
        '''
        Unix timestamp in seconds of when the token expires.
        '''
        return self._now + self.expires_in

    def is_expired(self, *, leeway: int = 0) -> bool:
        '''
        Check if the token is expired.

        :param leeway: Safety margin in seconds. A positive value reports the
            token as expired that many seconds early, e.g. to avoid sending a
            token that lapses while the request is in flight. Defaults to 0,
            which compares against expires_at exactly.
        :return: True once the current time plus ``leeway`` reaches
            expires_at.
        '''
        return _now(reason='token-expiration-check') + leeway >= self.expires_at


class OIDCClient:
    '''
    Wraps requests library make_request() and manages OIDC access token.

    Tokens are obtained with the OAuth 2.0 client-credentials grant. One is
    fetched on first use and whenever the cached token has expired. It is
    fetched once more if the API answers 401 to a request made with a token
    that was not generated during that same call.
    '''

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_url: str = DEFAULT_OIDC_TOKEN_ENDPOINT,
        scopes: Optional[list[str]] = None,
        base_url: str = '',
    ) -> None:
        '''
        :param client_id: Client id sent to the token endpoint.
        :param client_secret: Client secret sent to the token endpoint.
        :param token_url: Token endpoint URL.
        :param scopes: Scopes to request; defaults to ``['api.console']``.
        :param base_url: Prefix prepended verbatim to every ``endpoint``
            passed to make_request().
        '''
        self.client_id: str = client_id
        self.client_secret: str = client_secret
        self.token_url: str = token_url
        # Build the default per instance rather than sharing a mutable default.
        self.scopes: list[str] = ['api.console'] if scopes is None else scopes
        self.base_url: str = base_url
        self.token: Optional[Token] = None

    @classmethod
    def _json_response_to_token(cls, json_response: Any) -> Token:
        '''
        Build a Token from the decoded token-endpoint JSON body.

        Raises KeyError if a required field is missing, and ValueError if
        ``token_type`` is not a known TokenType.
        '''
        return Token(
            access_token=json_response['access_token'],
            expires_in=json_response['expires_in'],
            refresh_expires_in=json_response['refresh_expires_in'],
            token_type=TokenType(json_response['token_type']),
            not_before_policy=json_response['not-before-policy'],
            scope=json_response['scope'],
        )

    def _generate_access_token(self) -> None:
        '''
        Fetch a new access token using client credentials and store it in
        ``self.token``.

        :raises TokenError: if the token request cannot be sent, returns an
            HTTP error status, or its body cannot be turned into a Token. The
            original exception is available as ``__cause__``.
        '''
        # RFC 6749 section 3.3: ``scope`` is one space-delimited string.
        # Passing a list would be form-encoded as repeated ``scope=`` fields.
        # A plain string is passed through for callers that supplied one.
        if isinstance(self.scopes, str):
            scope = self.scopes
        else:
            scope = ' '.join(self.scopes)

        response = None
        try:
            # The POST itself is inside the try so connection errors also
            # surface as TokenError, not as a bare RequestException.
            response = requests.post(
                self.token_url,
                data={
                    'grant_type': 'client_credentials',
                    'client_id': self.client_id,
                    'client_secret': self.client_secret,
                    'scope': scope,
                },
                headers={'Content-Type': 'application/x-www-form-urlencoded'},
            )
            response.raise_for_status()
        except requests.RequestException as e:
            raise TokenError(response=response) from e

        try:
            self.token = OIDCClient._json_response_to_token(response.json())
        except (ValueError, KeyError, TypeError) as e:
            # ValueError covers invalid JSON (requests' JSONDecodeError is a
            # ValueError) and an unknown token_type; KeyError/TypeError cover
            # a body missing fields or not shaped like an object.
            raise TokenError('Token response could not be parsed', response=response) from e

    def _add_headers(self, headers: dict[str, str]) -> None:
        '''
        Add the bearer ``Authorization`` header and ``Accept:
        application/json`` to *headers* in place. Existing values for those
        keys are overwritten. Requires ``self.token`` to be set.
        '''
        headers.update(
            {
                'Authorization': f'Bearer {self.token.access_token}',
                'Accept': 'application/json',
            }
        )

    def _make_request(self, method: str, url: str, headers: dict[str, str], **kwargs: Any) -> requests.Response:
        '''
        Actually make an API call: add auth headers to *headers* (in place)
        and send the request with ``requests.request``.
        '''
        self._add_headers(headers)
        return requests.request(method, url, headers=headers, **kwargs)

    def make_request(self, method: str, endpoint: str, **kwargs: Any) -> requests.Response:
        '''
        Make an authenticated request to ``base_url + endpoint``.

        A token is generated first if none is cached or the cached one has
        expired. If the API responds 401 and the token was not generated
        during this call, the token is regenerated and the request is retried
        once.

        :param method: HTTP method passed to ``requests.request``.
        :param endpoint: Path appended verbatim to ``base_url``.
        :param kwargs: Forwarded to ``requests.request``. A ``headers``
            mapping is copied and never modified in place; ``headers=None``
            is treated as no headers.
        :raises TokenError: if generating a token fails.
        '''
        token_is_fresh = False
        if self.token is None or self.token.is_expired():
            self._generate_access_token()
            token_is_fresh = True

        url = f'{self.base_url}{endpoint}'
        # Copy so _add_headers() does not leak the Authorization header into
        # the caller's dict.
        caller_headers = kwargs.pop('headers', None)
        headers = {} if caller_headers is None else caller_headers.copy()

        response = self._make_request(method, url, headers, **kwargs)
        if not token_is_fresh and response.status_code == 401:
            # The cached token was rejected (e.g. revoked or expired
            # server-side). Release the rejected response's connection, then
            # retry once with a new token.
            response.close()
            self._generate_access_token()
            response = self._make_request(method, url, headers, **kwargs)

        return response
